{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle, yaml\n",
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "PROJECT_DIRECTORY = Path(os.path.abspath(\"\")).parent\n",
    "SAVE_FIG_PATH = PROJECT_DIRECTORY / \"_static\"\n",
    "RESULTS_PATH = PROJECT_DIRECTORY / \"results\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def saveFig(name, fig):\n",
    "    fig.savefig(\n",
    "        name,\n",
    "        dpi=None,\n",
    "        facecolor=fig.get_facecolor(),\n",
    "        edgecolor=\"none\",\n",
    "        orientation=\"portrait\",\n",
    "        format=\"png\",\n",
    "        transparent=False,\n",
    "        bbox_inches=\"tight\",\n",
    "        pad_inches=0.2,\n",
    "        metadata=None,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_pickle(path_to_pickle):\n",
    "    with open(path_to_pickle, \"rb\") as handle:\n",
    "        data = pickle.load(handle)\n",
    "    return data\n",
    "\n",
    "def read_config(config_directory):\n",
    "    with open(config_directory / \"config.json\", \"r\") as file:\n",
    "        config = yaml.safe_load(file)\n",
    "    return config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fuse_by_dataset(losses):\n",
    "    \"\"\"Transform per-round history (list of dicts) into\n",
    "\n",
    "    a single dict, with values as lists.\"\"\"\n",
    "    fussed_losses = {}\n",
    "\n",
    "    for _, loss_dict in losses:\n",
    "        for k, v in loss_dict.items():\n",
    "            if k in fussed_losses:\n",
    "                fussed_losses[k].append(v)\n",
    "            else:\n",
    "                fussed_losses[k] = [v]\n",
    "    return fussed_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_multirun_data(path_multirun):\n",
    "    \"\"\"Given a path to a multirun directory, this loads the history of all runs\"\"\"\n",
    "    res_list = []\n",
    "    for results in list(Path(path_multirun).glob(\"**/*.pkl\")):\n",
    "        config = read_config(Path(results).parent)\n",
    "        data = read_pickle(results)\n",
    "        pre_train_loss = data[\"history\"].metrics_distributed_fit[\"pre_train_losses\"]\n",
    "        fussed_losses = fuse_by_dataset(pre_train_loss)\n",
    "        res_list.append(\n",
    "            {\n",
    "                \"strategy\": config[\"algorithm-name\"],\n",
    "                \"train_losses\": fussed_losses,\n",
    "            }\n",
    "        )\n",
    "    return res_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_losses = process_multirun_data(RESULTS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def average_by_client_type(all_fused_lossed):\n",
    "    \"\"\"If there are multliple runs for the same strategy add them up,\n",
    "\n",
    "    average them later. This is useful if you run the `--multirun` running\n",
    "    more than one time the same configuration.\"\"\"\n",
    "\n",
    "    # identify how many unique clients were used\n",
    "    to_plot = {}\n",
    "    for res in all_fused_lossed:\n",
    "        strategy = res[\"strategy\"]\n",
    "        if strategy not in to_plot:\n",
    "            to_plot[strategy] = {}\n",
    "\n",
    "        for dataset, train_loss in res[\"train_losses\"].items():\n",
    "            if dataset in to_plot[strategy]:\n",
    "                to_plot[strategy][dataset][\"train_loss\"] += np.array(train_loss)\n",
    "                to_plot[strategy][dataset][\"run_count\"] += 1\n",
    "            else:\n",
    "                to_plot[strategy][dataset] = {\"train_loss\": np.array(train_loss)}\n",
    "                to_plot[strategy][dataset][\"run_count\"] = 1\n",
    "\n",
    "    # print(to_plot)\n",
    "    return to_plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "to_plot = average_by_client_type(all_losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = to_plot[list(to_plot.keys())[0]].keys()\n",
    "print(datasets)\n",
    "\n",
    "num_datasets = len(datasets)\n",
    "fig, axs = plt.subplots(figsize=(14, 4), nrows=1, ncols=num_datasets)\n",
    "\n",
    "\n",
    "for s_id, (strategy, results) in enumerate(to_plot.items()):\n",
    "    for i, dataset in enumerate(datasets):\n",
    "        loss = results[dataset][\"train_loss\"] / results[dataset][\"run_count\"]\n",
    "        axs[i].plot(range(len(loss)), loss, label=strategy)\n",
    "        axs[i].set_xlabel(\"Round\")\n",
    "        if i == 0:\n",
    "            axs[i].set_ylabel(\"Train Loss\")\n",
    "\n",
    "        axs[i].legend()\n",
    "\n",
    "        if s_id == 0:\n",
    "            axs[i].grid()\n",
    "            axs[i].set_title(dataset)\n",
    "            axs[i].set_xticks(np.arange(0, 100 + 1, 25))\n",
    "\n",
    "\n",
    "saveFig(SAVE_FIG_PATH / \"train_loss.png\", fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fedbn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
